import torch
import torch.nn as nn


class PlasticityModel(nn.Module):
    """Plastic return mapping on logarithmic principal stretches.

    The deviatoric part of the log principal stretches is scaled back onto a
    sphere of radius ``yield_stress / (2 * mu)`` whenever it lies outside it.
    The volumetric (mean) log stretch is left untouched.
    """

    def __init__(self, yield_stress: float = 0.22, mu_log: float = 4.0):
        """
        Define trainable continuous physical parameters for differentiable optimization.
        Initialize yield_stress and plastic shear modulus (mu) in log space.

        Args:
            yield_stress (float): yield stress controlling plastic threshold.
            mu_log (float): log shear modulus for plastic correction.
        """
        super().__init__()
        self.yield_stress = nn.Parameter(torch.tensor(yield_stress))  # scalar
        self.mu_log = nn.Parameter(torch.tensor(mu_log))  # scalar

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Compute corrected deformation gradient via logarithmic spectral plasticity.

        Args:
            F (torch.Tensor): deformation gradient tensor of shape (..., 3, 3),
                e.g. (B, 3, 3). Any number of leading batch dims is accepted.

        Returns:
            F_corrected (torch.Tensor): corrected deformation gradient, same
                shape as ``F``. Samples whose deviatoric log strain does not
                exceed the yield threshold are returned unchanged (bitwise).
        """
        mu = self.mu_log.exp()  # scalar; log-parameterized so mu stays positive

        # SVD over the last two dims; leading dims are treated as batch.
        U, sigma, Vh = torch.linalg.svd(F)  # U, Vh: (..., 3, 3); sigma: (..., 3)

        # Clamp before the log so collapsed directions do not produce -inf.
        epsilon = torch.log(torch.clamp_min(sigma, 1e-6))  # (..., 3)

        # Split log strain into volumetric mean and deviatoric part.
        epsilon_mean = epsilon.mean(dim=-1, keepdim=True)  # (..., 1)
        epsilon_bar = epsilon - epsilon_mean  # (..., 3)
        epsilon_bar_norm = torch.linalg.norm(epsilon_bar, dim=-1, keepdim=True)  # (..., 1)

        # Plastic multiplier: how far the deviatoric strain lies past the yield radius.
        delta_gamma = epsilon_bar_norm - self.yield_stress / (2 * mu)  # (..., 1)
        yielded = delta_gamma > 0  # (..., 1)

        # Radial scaling of the deviatoric strain back to the yield radius;
        # the norm is clamped to avoid dividing by zero.
        scale = 1.0 - torch.clamp_min(delta_gamma, 0.0) / epsilon_bar_norm.clamp_min(1e-8)
        scale = torch.where(yielded, scale, torch.ones_like(scale))  # (..., 1)

        # Recompose the corrected log strain and map back to a deformation gradient.
        epsilon_corrected = epsilon_bar * scale + epsilon_mean  # (..., 3)
        F_yield = U @ torch.diag_embed(torch.exp(epsilon_corrected)) @ Vh  # (..., 3, 3)

        # Samples inside the yield surface pass through untouched. This avoids
        # SVD round-off drift on them and routes their gradient directly to F.
        # The SVD branch then receives exactly zero upstream gradient for those
        # samples. NOTE(review): per the torch.linalg.svd docs, its backward is
        # only finite for distinct singular values. Inputs with exactly repeated
        # singular values (e.g. F == I exactly) may still need a custom SVD
        # backward if NaN gradients appear.
        return torch.where(yielded.unsqueeze(-1), F_yield, F)


class ElasticityModel(nn.Module):
    """St. Venant-Kirchhoff elasticity returning the Kirchhoff stress.

    Young's modulus is stored in log space. Poisson's ratio is stored as a raw
    value mapped through ``sigmoid(.) * 0.499`` into (0, 0.499).
    """

    def __init__(self, youngs_modulus_log: float = 12.9, poissons_ratio_sigmoid: float = 0.0):
        """
        Define trainable continuous physical parameters for differentiable optimization.
        Initialize parameters from best prior estimates.

        Args:
            youngs_modulus_log (float): log of Young's modulus.
            poissons_ratio_sigmoid (float): raw Poisson's ratio parameter before sigmoid scaling.
        """
        super().__init__()
        self.youngs_modulus_log = nn.Parameter(torch.tensor(youngs_modulus_log))  # scalar
        self.poissons_ratio_sigmoid = nn.Parameter(torch.tensor(poissons_ratio_sigmoid))  # scalar

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Compute Kirchhoff stress from a deformation gradient using StVK elasticity.

        Args:
            F (torch.Tensor): deformation gradient tensor of shape (..., 3, 3),
                e.g. (B, 3, 3). Any number of leading batch dims is accepted.

        Returns:
            kirchhoff_stress (torch.Tensor): Kirchhoff stress tau = F S F^T,
                same shape as ``F``.
        """
        # Physical parameters
        youngs_modulus = self.youngs_modulus_log.exp()  # scalar

        # Sigmoid mapping to (0, 0.499) keeps the Lame parameter la finite.
        poissons_ratio = torch.sigmoid(self.poissons_ratio_sigmoid) * 0.499  # scalar

        # Lame parameters
        mu = youngs_modulus / (2.0 * (1.0 + poissons_ratio))  # scalar
        la = youngs_modulus * poissons_ratio / ((1.0 + poissons_ratio) * (1.0 - 2.0 * poissons_ratio))  # scalar

        # (3, 3) identity broadcasts against any number of leading batch dims.
        I = torch.eye(3, dtype=F.dtype, device=F.device)

        Ft = F.transpose(-2, -1)  # (..., 3, 3)

        # Green-Lagrange strain E = (F^T F - I) / 2
        E = 0.5 * (torch.matmul(Ft, F) - I)  # (..., 3, 3)

        # Trace of E, kept broadcastable against (..., 3, 3).
        trE = E.diagonal(dim1=-2, dim2=-1).sum(dim=-1)[..., None, None]  # (..., 1, 1)

        # Second Piola-Kirchhoff stress
        S = 2.0 * mu * E + la * trE * I  # (..., 3, 3)

        # First Piola-Kirchhoff P = F S, then Kirchhoff tau = P F^T
        P = torch.matmul(F, S)  # (..., 3, 3)
        return torch.matmul(P, Ft)  # (..., 3, 3)
